import torch

from .base import BasePruner
from .utils import check_sparsity, find_layers


class StructuredPruner(BasePruner):
    """N:M semi-structured pruner.

    For every layer returned by ``find_layers``, and within every group of
    ``prune_m`` consecutive columns of that layer's score matrix, the
    ``prune_n`` lowest-scoring entries of each row have their weights zeroed.
    Score matrices are obtained from ``self.W_metrics``, populated by
    ``calculate_scores`` (inherited; implementation not visible here).
    """

    def __init__(
        self,
        scores="weight",
        sparsity_ratio=0.0,
        n_samples=0,
        seed=0,
        dataset_name="c4",
        **kwargs
    ):
        """Create the pruner.

        Args:
            scores: Forwarded to ``BasePruner``.
            sparsity_ratio: Forwarded to ``BasePruner``.
            n_samples: Forwarded to ``BasePruner``.
            seed: Forwarded to ``BasePruner``.
            dataset_name: Forwarded to ``BasePruner``.
            **kwargs: Must contain ``prune_n`` (entries pruned per group) and
                ``prune_m`` (group width in columns). All kwargs, including
                these two, are also forwarded to ``BasePruner``.

        Raises:
            KeyError: If ``prune_n`` or ``prune_m`` is missing from kwargs.
        """
        super().__init__(
            scores, sparsity_ratio, n_samples, seed, dataset_name, **kwargs
        )
        self.prune_n = kwargs["prune_n"]
        self.prune_m = kwargs["prune_m"]

    @staticmethod
    def _get_layers(model):
        """Return the container of layers to prune from ``model``."""
        try:
            return model.model.layers
        except AttributeError:
            # Alternate layout where the layers live under a `decoder` submodule.
            return model.model.decoder.layers

    @staticmethod
    def _nm_mask(W_metric, prune_n, prune_m):
        """Build the boolean N:M prune mask for one score matrix.

        Within each block of ``prune_m`` consecutive columns, the ``prune_n``
        smallest scores of every row are marked ``True`` (to be zeroed).
        Assumes ``W_metric`` is 2-D (it is indexed as ``[:, cols]``).
        NOTE(review): if the column count is not a multiple of ``prune_m``,
        the final partial block must still have at least ``prune_n`` columns,
        otherwise ``torch.topk`` raises.
        """
        W_mask = torch.zeros_like(W_metric, dtype=torch.bool)
        for start in range(0, W_metric.shape[1], prune_m):
            # Scores are compared in float32 regardless of their stored dtype.
            block = W_metric[:, start : start + prune_m].float()
            lowest = torch.topk(block, prune_n, dim=1, largest=False)[1]
            # topk indices are block-relative; shift them to absolute columns.
            W_mask.scatter_(1, start + lowest, True)
        return W_mask

    def prune(self, model, tokenizer, device):
        """Zero weights of every layer found by ``find_layers`` using N:M masks.

        Scores are computed via ``calculate_scores`` on first use (when
        ``self.W_metrics`` is ``None``) and consumed in layer/sublayer order.

        Args:
            model: Model exposing ``model.model.layers`` or, failing that,
                ``model.model.decoder.layers``.
            tokenizer: Forwarded to ``calculate_scores``.
            device: Forwarded to ``calculate_scores``.

        Returns:
            ``(model, check_sparsity(model))``; weights are modified in place.
        """
        if self.W_metrics is None:
            self.calculate_scores(model, tokenizer, device)

        cnt = 0
        for layer in self._get_layers(model):
            subset = find_layers(layer)
            for name in subset:
                W = subset[name].weight.data
                W_mask = self._nm_mask(self.W_metrics[cnt], self.prune_n, self.prune_m)
                W[W_mask] = 0
                # Release the score tensor as soon as it has been used.
                # NOTE(review): deleting index `cnt` and then incrementing it is
                # only correct if W_metrics is keyed by `cnt` (e.g. a dict); for
                # a list, entries would shift and be skipped — confirm in
                # BasePruner.calculate_scores.
                del self.W_metrics[cnt]
                cnt += 1
        torch.cuda.empty_cache()

        return model, check_sparsity(model)


def get(**kwargs):
    """Construct and return a ``StructuredPruner``.

    Every keyword argument is passed unchanged to the ``StructuredPruner``
    constructor.
    """
    pruner = StructuredPruner(**kwargs)
    return pruner
